import argparse
import copy
import os
import random
from typing import List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

from utils import get_env, get_agent, normalize_reward, take_env_step, perturb_action, discretize_state, save_arrlc_model




def evaluate(agent, env, args, action_dim: Optional[int] = None) -> Tuple[float, float]:
    """Run ``args.num_eval_trajectory`` evaluation rollouts and summarize their returns.

    Each rollout lasts exactly ``args.episode_steps_test`` steps and accumulates
    the raw (un-normalized) rewards returned by ``take_env_step``. When an
    episode terminates early:

    * FrozenLake environments are reset and the rollout continues;
    * for CliffWalking-v0, CartPole-v1, InvertedPendulum-v4 and MountainCar-v0
      the rollout is flagged ``already_done`` and the remaining steps are still
      passed through ``take_env_step`` with that flag set;
    * any other environment raises ``ValueError``.

    Args:
        agent: object exposing ``take_action(state, step, is_train=...)``.
        env: environment whose ``reset()`` result holds the observation at index 0.
        args: parsed CLI namespace; reads ``num_eval_trajectory``,
            ``episode_steps_test`` and ``env_name``, and is forwarded to
            ``perturb_action``.
        action_dim: action count forwarded to ``perturb_action``. If omitted,
            ``args.action_dim`` is used when present, otherwise the module-level
            ``action_dim`` defined by this script's ``__main__`` block.

    Returns:
        ``(mean, std)`` of the per-rollout summed rewards.

    Raises:
        ValueError: if no action dimension can be resolved, or an episode
            terminates in an environment without termination handling.
    """
    if action_dim is None:
        action_dim = getattr(args, "action_dim", None)
    if action_dim is None:
        # Legacy path: the original code read the global that the __main__
        # block defines; keep that working for existing script runs.
        action_dim = globals().get("action_dim")
    if action_dim is None:
        raise ValueError("evaluate() needs action_dim: pass it explicitly or set args.action_dim")

    # Environments whose episodes end permanently (no auto-reset) within a rollout.
    stop_on_done_envs = ("CliffWalking-v0", "CartPole-v1", "InvertedPendulum-v4", "MountainCar-v0")

    trajectory_rewards = []
    for _ in range(args.num_eval_trajectory):
        cur_reward = 0
        state = env.reset()[0]
        done = False
        already_done = False
        for i_step in range(args.episode_steps_test):
            if done:
                if "FrozenLake" in args.env_name:
                    state = env.reset()[0]
                elif args.env_name in stop_on_done_envs:
                    already_done = True
                else:
                    raise ValueError(f"no termination handling for env_name={args.env_name!r}")

            action = agent.take_action(discretize_state(state, args.env_name), i_step, is_train=False)
            action = perturb_action(args, action, action_dim)
            next_state, reward, done = take_env_step(env, action, already_done, state, args.env_name)
            cur_reward += reward
            state = next_state

        trajectory_rewards.append(cur_reward)

    return np.mean(trajectory_rewards), np.std(trajectory_rewards)





def train(agent, env, args) -> Tuple[List, Optional[dict]]:
    """Train ``agent`` on ``env`` for ``args.num_episodes`` episodes, evaluating periodically.

    Each episode runs ``args.episode_steps_train`` steps. CliffWalking-v0 and
    MountainCar-v0 are reset only after a terminal step was observed, so the
    state otherwise carries over between episodes; every other environment is
    reset at the start of each episode. On termination, FrozenLake envs are
    reset immediately and the reset observation is used as ``next_state`` in
    the update; for CliffWalking-v0, CartPole-v1, InvertedPendulum-v4 and
    MountainCar-v0 the ``already_done`` flag is set and passed to
    ``take_env_step`` for the remaining steps. ``agent.update_qv()`` is called
    after every episode.

    Every ``args.eval_frequency`` episodes (episode 0 is skipped) ``evaluate`` is
    run, and the agent's attributes are snapshotted when ``mean - std`` of the
    evaluation returns improves.

    Returns:
        ``(return_list, best_model)``. ``return_list`` holds one ``[mean, std]``
        pair per evaluation; entry k was produced at episode index
        ``(k + 1) * args.eval_frequency``. ``best_model`` is a deep copy of
        ``agent.__dict__`` at the best-scoring evaluation, or ``None`` if no
        evaluation ran.
    """
    return_list = []
    already_done = True
    best_model = None
    best_results = -float("inf")
    num_episodes = int(args.num_episodes)
    with tqdm(total=num_episodes) as pbar:
        for i_episode in range(num_episodes):
            if args.env_name not in ("CliffWalking-v0", "MountainCar-v0") or already_done:
                state = env.reset()[0]

            already_done = False

            for i_step in range(args.episode_steps_train):
                action = agent.take_action(discretize_state(state, args.env_name), i_step, is_train=True)
                next_state, reward, done = take_env_step(env, action, already_done, state, args.env_name)
                if done:
                    if "FrozenLake" in args.env_name:
                        next_state = env.reset()[0]
                    elif args.env_name in ("CliffWalking-v0", "CartPole-v1", "InvertedPendulum-v4", "MountainCar-v0"):
                        already_done = True
                agent.update(discretize_state(state, args.env_name), action, normalize_reward(reward, args), discretize_state(next_state, args.env_name), i_step)
                state = next_state

            agent.update_qv()

            if i_episode > 0 and i_episode % args.eval_frequency == 0:
                eval_mean, eval_std = evaluate(agent, env, args)
                return_list.append([eval_mean, eval_std])
                # Model selection uses a pessimistic score: mean minus one std.
                score = eval_mean - eval_std
                if score > best_results:
                    best_results = score
                    # agent.__dict__ is the live attribute dict; without a deep
                    # copy the "best" snapshot would keep changing as training
                    # continues and end up equal to the final agent.
                    best_model = copy.deepcopy(agent.__dict__)
                recent = np.array(return_list)[-10:]
                pbar.set_postfix({
                    'mean': '%.3f' % np.mean(recent[:, 0]),
                    'std': '%.3f' % np.mean(recent[:, 1])
                })
            pbar.update(1)

    return return_list, best_model





# python train.py --num_episodes 3000 --num_eval_trajectory 50 --rho 0 --env_name "CartPole-v1" --agent_type "orlc"
# python train.py --num_episodes 3000 --num_eval_trajectory 50 --rho 0 --env_name "CliffWalking-v0" --agent_type "orlc"


# Environments whose state/action counts are hard-coded here instead of being
# read from `env.observation_space.n` / `env.action_space.n` by the __main__
# block. Each entry is (state count, action count); presumably these are the
# sizes of the discretized grids produced by the helpers in `utils` -- confirm there.
_ENV_DIMS = {
    "CartPole-v1": (6 ** 4, 2),  # 6 ** 4: presumably 6 bins per each of 4 observation dims
    "InvertedPendulum-v4": (160, 7),
    "MountainCar-v0": (340, 3),
}
STATE_DIM = {env_key: dims[0] for env_key, dims in _ENV_DIMS.items()}
ACTION_DIM = {env_key: dims[1] for env_key, dims in _ENV_DIMS.items()}

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_episodes', type=int, default=3000, help='number of episodes to run')
    parser.add_argument('--episode_steps_train', type=int, default=100, help='number of steps to run in each episode')
    parser.add_argument('--episode_steps_test', type=int, default=100, help='number of steps to run in each evaluation trajectory')
    parser.add_argument('--eval_frequency', type=int, default=20, help='frequency of running evaluation during training')
    parser.add_argument('--num_eval_trajectory', type=int, default=50, help='number of trajectory runs in each evaluation')
    parser.add_argument('--random_seed', type=int, default=0, help='random seed')
    parser.add_argument('--alpha', type=float, default=0.1, help='learning rate for q learning')
    parser.add_argument('--gamma', type=float, default=1, help='discount factor for reward')
    parser.add_argument('--epsilon', type=float, default=0.01, help='epsilon greedy')
    # NOTE(review): --p and --rho share the same help text; confirm which attack
    # parameter each one controls (both are consumed outside this file).
    parser.add_argument('--p', type=float, default=0, help='random attack rate')
    parser.add_argument('--rho', type=float, default=0, help='random attack rate')
    parser.add_argument('--iota', type=float, default=1, help='')
    parser.add_argument('--const', type=float, default=100000, help='')
    parser.add_argument('--perturb_type', type=str, default="random", help='choose from ["fix", "random"]')
    parser.add_argument('--agent_type', type=str, default="arrlc", help='choose from ["arrlc", "orlc", "rq"]')
    parser.add_argument('--env_name', type=str, default="CliffWalking-v0", help='choose from ["Taxi-v3", "FrozenLake-v1", "CartPole-v1", "CliffWalking-v0"]')
    args = parser.parse_args()

    # Seed Python's and NumPy's global RNGs (the environment is not seeded here).
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    env = get_env(args)
    # Envs listed in STATE_DIM use hard-coded sizes; the rest are expected to
    # expose discrete spaces whose sizes are read directly from the env.
    if args.env_name in STATE_DIM:
        state_dim = STATE_DIM[args.env_name]
        action_dim = ACTION_DIM[args.env_name]
    else:
        state_dim = env.observation_space.n
        action_dim = env.action_space.n
    agent = get_agent(args, state_dim, action_dim)
    args.state_dim = state_dim
    # Expose action_dim on args too, so evaluate() need not depend on the
    # module-level variable defined just above.
    args.action_dim = action_dim

    return_list, best_model = train(agent, env, args)

    # The output directory name encodes the run configuration.
    output_path = os.path.join(
        "output",
        f"{args.agent_type}_{args.env_name}_{args.perturb_type}_rho{args.rho}_p{args.p}_{args.epsilon}_{args.gamma}_{args.num_episodes}_{args.const}",
    )
    os.makedirs(output_path, exist_ok=True)
    with open(os.path.join(output_path, "eval.npy"), "wb") as f:
        np.save(f, np.array(return_list))

    # Only the ARRLC agent's best snapshot is persisted.
    if args.agent_type == "arrlc":
        save_arrlc_model(best_model, output_path)

    returns = np.array(return_list)
    if len(returns) == 0:
        # train() evaluates only at episode indices >= eval_frequency.
        print("No evaluations were run (num_episodes <= eval_frequency); skipping plot.")
    else:
        # train() skips evaluation at episode 0, so evaluation k ran at
        # episode index (k + 1) * eval_frequency.
        episodes_list = [(k + 1) * args.eval_frequency for k in range(len(returns))]
        plt.plot(episodes_list, returns[:, 0])
        plt.xlabel('Episodes')
        plt.ylabel('Returns')
        plt.title(f'{args.agent_type} on {args.env_name}')
        plt.show()